{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import torch,gc\n",
    "from torch import nn\n",
    "from torch.autograd import Variable\n",
    "import torch.nn.functional as F\n",
    "from PIL import Image\n",
    "import matplotlib.pyplot as plt\n",
    "import pandas as pd\n",
    "from torchvision import transforms as tfs\n",
    "from torch.utils.data import Dataset, DataLoader\n",
    "import os\n",
    "from tqdm import notebook\n",
    "import matplotlib.pyplot as plt\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Preprocessing helpers. This cell must run before the loading cell below uses them.\n",
    "# transform1: PIL image -> float tensor (C, H, W) scaled to [0, 1].\n",
    "# transform2: per-channel normalisation with the ImageNet mean/std.\n",
    "transform1 = tfs.Compose([tfs.ToTensor()])\n",
    "transform2 = tfs.Compose([tfs.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))])\n",
    "\n",
    "\n",
    "def cut(x, piece=100):\n",
    "    \"\"\"Split a (C, H, W) image tensor into the four pieces of a 2x2 puzzle.\n",
    "\n",
    "    Only the top-left (2 * piece) x (2 * piece) region is used. Pieces are returned\n",
    "    as a (4, C, piece, piece) tensor in column-major order: top-left, bottom-left,\n",
    "    top-right, bottom-right.\n",
    "    NOTE(review): confirm this order matches the position encoding of the CSV labels.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if H or W is smaller than 2 * piece.\n",
    "    \"\"\"\n",
    "    height, width = x.shape[-2], x.shape[-1]\n",
    "    if height < 2 * piece or width < 2 * piece:\n",
    "        raise ValueError(f'expected an image of at least {2 * piece}x{2 * piece}, got {height}x{width}')\n",
    "    tiles = [\n",
    "        x[:, row * piece:(row + 1) * piece, col * piece:(col + 1) * piece]\n",
    "        for col in range(2)  # outer loop over columns gives the column-major order\n",
    "        for row in range(2)\n",
    "    ]\n",
    "    return torch.stack(tiles, dim=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the first `num` puzzles and their position labels into memory.\n",
    "# NOTE(review): this reads the *test* split, and the result is used for training below.\n",
    "# Paths are relative to the notebook's working directory (the project root).\n",
    "DATA_DIR = os.path.join('data', 'puzzle_2x2')\n",
    "rootdir = os.path.join(DATA_DIR, 'test')\n",
    "# os.listdir order is filesystem-dependent; sort it so the positional pairing with\n",
    "# the CSV rows is at least deterministic.\n",
    "# TODO(review): verify that the sorted file order really matches the row order of test.csv.\n",
    "imagelist = sorted(os.listdir(rootdir))\n",
    "num = 9000\n",
    "label = pd.read_csv(os.path.join(DATA_DIR, 'test.csv'), header=0).iloc[:num, 2]\n",
    "if len(imagelist) < num or len(label) < num:\n",
    "    raise ValueError(f'expected at least {num} images and labels, found {len(imagelist)} and {len(label)}')\n",
    "\n",
    "# Img[i] holds the 4 pieces of image i; Labels[i] holds its 4 target positions.\n",
    "# Everything is loaded up front, which needs a lot of memory.\n",
    "Img = torch.empty(num, 4, 3, 100, 100)\n",
    "Labels = torch.empty(num, 4, dtype=torch.long)\n",
    "for i in range(num):\n",
    "    # The context manager closes the file handle once the RGB copy has been made.\n",
    "    with Image.open(os.path.join(rootdir, imagelist[i])) as raw:\n",
    "        img = raw.convert(\"RGB\")\n",
    "    img = transform2(transform1(img))\n",
    "    Img[i] = cut(img)\n",
    "    Labels[i] = torch.tensor([int(tok) for tok in label.iloc[i].split()], dtype=torch.long)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Network: every puzzle piece is scored independently against the 4 possible positions.\n",
    "# For an input of shape (batch, 4, 3, 100, 100) the output is (batch * 4, 4) raw logits;\n",
    "# row k is piece k and the arg-max column is its predicted position.\n",
    "\n",
    "class JigsawNet(nn.Module):\n",
    "    \"\"\"CNN that predicts the grid position (0-3) of each 100x100 puzzle piece.\"\"\"\n",
    "\n",
    "    @staticmethod\n",
    "    def _conv_stage(in_ch, out_ch, kernel, stride, *tail):\n",
    "        \"\"\"Conv (padding 2) -> BatchNorm -> ReLU, followed by any extra `tail` modules.\"\"\"\n",
    "        return nn.Sequential(\n",
    "            nn.Conv2d(in_ch, out_ch, (kernel, kernel), padding=(2, 2), stride=(stride, stride)),\n",
    "            nn.BatchNorm2d(out_ch),\n",
    "            nn.ReLU(),\n",
    "            *tail\n",
    "        )\n",
    "\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        # Shapes are per piece, i.e. for a (4 * batch, 3, 100, 100) input.\n",
    "        self.pad = nn.ZeroPad2d(padding=(2, 2, 2, 2))  # -> 3 x 104 x 104\n",
    "        self.conv1 = self._conv_stage(3, 50, 5, 2, nn.MaxPool2d(2, stride=2, padding=0))  # -> 50 x 26 x 26\n",
    "        self.conv2 = self._conv_stage(50, 100, 5, 2)  # -> 100 x 13 x 13\n",
    "        self.conv3 = self._conv_stage(100, 100, 3, 2, nn.Dropout(p=0.3))  # -> 100 x 8 x 8\n",
    "        # Flatten keeps dim 0 (the piece axis): 200 x 10 x 10 = 20000 features per piece.\n",
    "        self.conv4 = self._conv_stage(100, 200, 3, 1, nn.Dropout(p=0.3), nn.Flatten(start_dim=1))\n",
    "        self.fn = nn.Sequential(\n",
    "            nn.Linear(20000, 600), nn.BatchNorm1d(600), nn.ReLU(),\n",
    "            nn.Linear(600, 400), nn.BatchNorm1d(400), nn.ReLU(),\n",
    "            nn.Dropout(p=0.3),\n",
    "            nn.Linear(400, 4),\n",
    "        )\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"Fold the piece axis into the batch axis and return (N, 4) position logits.\"\"\"\n",
    "        out = x.reshape((-1, 3, 100, 100))\n",
    "        for stage in (self.pad, self.conv1, self.conv2, self.conv3, self.conv4, self.fn):\n",
    "            out = stage(out)\n",
    "        # Raw logits on purpose: nn.CrossEntropyLoss applies log-softmax itself.\n",
    "        return out"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Dataset wrapper that serves (pieces, labels) pairs to the DataLoader.\n",
    "class My_dataset(Dataset):\n",
    "    \"\"\"Map-style dataset over pre-loaded piece and label tensors.\n",
    "\n",
    "    Args:\n",
    "        x: sample-indexed tensor; x[i] is returned as the input (defaults to the global Img).\n",
    "        y: sample-indexed tensor; y[i] is returned as the target (defaults to the global Labels).\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if x and y do not have the same number of samples.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, x=None, y=None):\n",
    "        super().__init__()\n",
    "        # Fall back to the notebook globals so the no-argument call keeps working.\n",
    "        self.x = Img if x is None else x\n",
    "        self.y = Labels if y is None else y\n",
    "        if len(self.x) != len(self.y):\n",
    "            raise ValueError(f'x has {len(self.x)} samples but y has {len(self.y)}')\n",
    "        # Index the tensors directly instead of copying every row into Python lists.\n",
    "        self.src, self.trg = self.x, self.y\n",
    "\n",
    "    def __getitem__(self, index):\n",
    "        return self.src[index], self.trg[index]\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self.src)\n",
    "\n",
    "\n",
    "data_train = My_dataset()\n",
    "# Shuffle so the batch composition changes from epoch to epoch during training.\n",
    "data_loader_train = DataLoader(data_train, batch_size=90, shuffle=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "def train(data_loader, loss_count):\n",
    "    \"\"\"Run one training epoch over `data_loader` and record its metrics.\n",
    "\n",
    "    Appends the mean batch loss (a float) to `loss_count` and the accuracy over\n",
    "    every piece seen this epoch (a float) to the global `acc_count`.\n",
    "\n",
    "    Reads the notebook globals `model`, `optimizer`, `loss_func`, `use_gpu`,\n",
    "    `epoch` (0-based, used only for display) and `acc_count`.\n",
    "    \"\"\"\n",
    "    device = 'cuda' if use_gpu else 'cpu'\n",
    "    model.train()\n",
    "    total_loss, n_batches = 0.0, 0\n",
    "    n_correct, n_seen = 0, 0\n",
    "    with notebook.tqdm(total=len(data_loader), desc=\"Epoch:{}\".format(epoch + 1), unit='g', leave=False) as pbar:\n",
    "        for X, Y in data_loader:\n",
    "            X = X.float().to(device)\n",
    "            # Flatten (batch, 4) labels to (batch * 4,) to line up with the (batch * 4, 4) logits.\n",
    "            Y = torch.flatten(Y).long().to(device)\n",
    "            optimizer.zero_grad()\n",
    "            out = model(X)\n",
    "            loss = loss_func(out, Y)\n",
    "            loss.backward()\n",
    "            optimizer.step()\n",
    "            # .item() yields plain numbers, so no autograd graph is kept alive across epochs.\n",
    "            total_loss += loss.item()\n",
    "            n_batches += 1\n",
    "            n_correct += (out.argmax(dim=1) == Y).sum().item()\n",
    "            n_seen += Y.numel()\n",
    "            pbar.update(1)\n",
    "    # Epoch-wide metrics (not just the last batch); guard against an empty loader.\n",
    "    accuracy = n_correct / n_seen if n_seen else 0.0\n",
    "    print('epoch' + str(epoch+1) + ' accuracy:\\t', accuracy)\n",
    "    loss_count.append(total_loss / n_batches if n_batches else float('nan'))\n",
    "    acc_count.append(accuracy)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Free unreachable Python objects first so their tensors are released,\n",
    "# then return cached, now-unused CUDA memory to the driver.\n",
    "gc.collect()\n",
    "if torch.cuda.is_available():\n",
    "    torch.cuda.empty_cache()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Epoch:1:   0%|          | 0/100 [00:00<?, ?g/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch1 accuracy:\t 0.24166666666666667\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Epoch:2:   0%|          | 0/100 [00:00<?, ?g/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch2 accuracy:\t 0.25833333333333336\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Epoch:3:   0%|          | 0/100 [00:00<?, ?g/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch3 accuracy:\t 0.24722222222222223\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Epoch:4:   0%|          | 0/100 [00:00<?, ?g/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch4 accuracy:\t 0.2611111111111111\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Epoch:5:   0%|          | 0/100 [00:00<?, ?g/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch5 accuracy:\t 0.25277777777777777\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Epoch:6:   0%|          | 0/100 [00:00<?, ?g/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch6 accuracy:\t 0.2388888888888889\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Epoch:7:   0%|          | 0/100 [00:00<?, ?g/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch7 accuracy:\t 0.25277777777777777\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Epoch:8:   0%|          | 0/100 [00:00<?, ?g/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch8 accuracy:\t 0.2638888888888889\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Epoch:9:   0%|          | 0/100 [00:00<?, ?g/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch9 accuracy:\t 0.26944444444444443\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Epoch:10:   0%|          | 0/100 [00:00<?, ?g/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch10 accuracy:\t 0.28055555555555556\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Epoch:11:   0%|          | 0/100 [00:00<?, ?g/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch11 accuracy:\t 0.2638888888888889\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Epoch:12:   0%|          | 0/100 [00:00<?, ?g/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch12 accuracy:\t 0.29444444444444445\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Epoch:13:   0%|          | 0/100 [00:00<?, ?g/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch13 accuracy:\t 0.2361111111111111\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Epoch:14:   0%|          | 0/100 [00:00<?, ?g/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch14 accuracy:\t 0.29444444444444445\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Epoch:15:   0%|          | 0/100 [00:00<?, ?g/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch15 accuracy:\t 0.28888888888888886\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Epoch:16:   0%|          | 0/100 [00:00<?, ?g/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch16 accuracy:\t 0.37222222222222223\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Epoch:17:   0%|          | 0/100 [00:00<?, ?g/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch17 accuracy:\t 0.35555555555555557\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Epoch:18:   0%|          | 0/100 [00:00<?, ?g/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch18 accuracy:\t 0.38333333333333336\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Epoch:19:   0%|          | 0/100 [00:00<?, ?g/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch19 accuracy:\t 0.4166666666666667\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Epoch:20:   0%|          | 0/100 [00:00<?, ?g/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch20 accuracy:\t 0.41944444444444445\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Epoch:21:   0%|          | 0/100 [00:00<?, ?g/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch21 accuracy:\t 0.4638888888888889\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Epoch:22:   0%|          | 0/100 [00:00<?, ?g/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch22 accuracy:\t 0.5416666666666666\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Epoch:23:   0%|          | 0/100 [00:00<?, ?g/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch23 accuracy:\t 0.5944444444444444\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Epoch:24:   0%|          | 0/100 [00:00<?, ?g/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch24 accuracy:\t 0.6111111111111112\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Epoch:25:   0%|          | 0/100 [00:00<?, ?g/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch25 accuracy:\t 0.6472222222222223\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Epoch:26:   0%|          | 0/100 [00:00<?, ?g/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch26 accuracy:\t 0.6944444444444444\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Epoch:27:   0%|          | 0/100 [00:00<?, ?g/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch27 accuracy:\t 0.6972222222222222\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Epoch:28:   0%|          | 0/100 [00:00<?, ?g/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch28 accuracy:\t 0.7416666666666667\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Epoch:29:   0%|          | 0/100 [00:00<?, ?g/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch29 accuracy:\t 0.7416666666666667\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Epoch:30:   0%|          | 0/100 [00:00<?, ?g/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch30 accuracy:\t 0.775\n"
     ]
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYIAAAEGCAYAAABo25JHAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8QVMy6AAAACXBIWXMAAAsTAAALEwEAmpwYAAAeFklEQVR4nO3de7yVY/7/8ddn70qpVLT5RucUCuOwNTQYpxyKDr4GIcLIYZiiaTCI4TvEhL78DMohOYQc0uDnMGloUOxMDmWQhB0qSeeD6vP741r7127be7cP6973utd6Px+P+7HWute99/25H6vWe9/XfV/XZe6OiIjkrry4CxARkXgpCEREcpyCQEQkxykIRERynIJARCTH1Yu7gOpq2bKlt2/fPu4yREQSZebMmd+7e0F57yUuCNq3b09RUVHcZYiIJIqZfVnRe2oaEhHJcQoCEZEcpyAQEclxCgIRkRynIBARyXEKAhGRHKcgEBHJcbkTBF99BVdeCXPmxF2JiEhGiSwIzOwBM1tkZh9tZbsDzGyDmZ0UVS0AvPUW/PWv0K0b7L8/jB4NCxdGuksRkSSI8oxgHHBsZRuYWT5wM/BKhHUEp54KCxaEADCDSy+FXXaBXr1gwgRYvTryEkREMlFkQeDubwA/bGWzS4CngUVR1bGFnXaCIUOgqAhmz4Y//hE++ghOOw3+67/g7LPhtddg06Y6KafObNwIP/yQfcclImlhUU5VaWbtgefdfc9y3tsFeAw4HHggtd1TFfyewcBggLZt2+7/5ZcVDplRfZs2wRtvwMMPw8SJsGIFtG4NRx4JHTtChw7hsWPHEBZm6dt3uqxbB19/DV9+ueUyf354LC6GDRugXj1o1SqcCe2885aPpZ83bRr3EYlImpnZTHcvLPe9GINgInCru083s3FUEgSlFRYWemSDzq1eDX//Ozz6KLz3XmhKKq1RI2jffnMwdOgA7dpBixaw3XbQrFl43G472GabqofGpk2wZg2sWrV5+fHHsCxbtvl5ecs338C330LpzzEvL3ypt28f6mvXDgoKYNGicEzffLP5cdmyn9fTvDl06rT5OEs/b9MmBEomc4cPPoAXX4QXXgjPO3WCrl3DNaKuXcPSsWPmH4tImmRqEHwBlHxTtgRWA4PdfVJlvzPSIChr7drwF/W8efDFF+Gx9PPlyyv+2fr1twyGZs3CF9SqVSFwSr7wV68OIVAVjRqFL+mSpVmz8Bd+yZd9yRd/69Zh/1WxatWWwbBgQTiTmDcPPv88PP/pp83b16sX9lEShDvtBDvuuHkpKAiP228P+flVqyEdVq6EKVPCl/+LL4azIID99oPu3cNxzJkT7h4r0aAB7LbblgHRuXM4Lp0VSZbJyCAos904MuGMoDrcYenS8MWybFlYli8PS3nPly0LZwiNG8O22275WN660l/4JV/622xT98e5cWMIh88/3xwOJYE4fz58//2WZyMl8vKgZcvN4dC4cQiGvLzwWLKUfV0SoC1abD72kuclj82ahUCaO3fzX/3//CesXx++wHv2hN694bjjQlCWtmIF/Oc/IRRmzw6Pc+aEcC+toGDLZsHSz1u31pmEJE4sQWBmE4DDCH/tLwSuBeoDuPs9ZbYdR9KCQIKNG2HJEli8ODQ9lV0WLw636a5ZE5rANm7ccim7bv36EJobN1a+38aNw9kMhL/qe/cOy8EHh7/0q2vVqhAQJUFX+gzwq6/CNZYS9erBnnvC+PGw117V35dIDGI7I4iCgiAHlDShLV26+VpIyfPS6zp2DLf/duoUbT0bNoSmptIBMW5caI56+mk46qho9y+SBgoCkXQrLg5nIHPmwNixMGhQ3BWJVKqyIMidISZE0ql1a5g2DQ4/PPQ/ufba8q+ViCSAgkCkprbbLlyoPvtsuP76cFawfn3cVYlUm259EKmN+vXh/vvD9YprrglNRk8/He5uEkkInRGI1JYZXH11uIto2rRw51Lp
/goiGU5BIJIuAwfCSy+Fs4Jf/jL0Tt+adetCz+fHHw93IpXuvCdSR9Q0JJJORxwBb74Zbms99FB48snwfO1a+PTTzZ3YZs8Oy+efb9lnYswYeOKJMJSHSB1REIikW7duMH06HH88nHBC6Ofw+eebR3/Nz4dddw2d0k4+OWzfrVsYCXfwYNh3X3jkETi20lHcRdJGQSAShVat4PXXYdiwMAzHKads/sLv0qX84UL22itMmnTSSWF4jKuuguuu03AWEjl1KBPJNGvWwCWXhLuRDjsMHnvs52MmiVSTOpSJJEmjRnDfffDQQ/DOO6Gp6LXX4q5KspiCQCRTnXlmCILttw8jqt5wg2aZk0goCEQyWbduIQwGDIARI8K1g8WL465KsoyCQCTTNWkSplIdMyZcgN533xAOImmiIBBJAjM477xwW2qDBqFvwty5cVclWUJBIJIk++wDr7wSRjrt3TvMzyBSSwoCkaTZdVeYNClMkvPf/60RT6XWFAQiSXTIIaGfwdSpcOGFmgtBakVdFkWSauBA+OyzcFtply5w+eVxVyQJpSAQSbI//zmEwRVXhDGNTjop7ookgdQ0JJJkZvDgg3DQQeEMQbeVSg0oCESSrmFDeO65MB5Rnz6aFEeqTUEgkg0KCuD558O8B717w/LlcVckCaIgEMkWXbvCU0/Bxx+HYa83bIi7IkkIBYFINjnqKLj77jBl5pAhuq1UqkR3DYlkm/POC9NijhoVbisdMiTuiiTDKQhEstHIkWEsoqFDYf58uPHGMM+BSDnUNCSSjfLzw8xmF18Mo0dDYSHMmhV3VZKhFAQi2apRI7jzznC9YOlS6N4dbr4ZNm6MuzLJMAoCkWx3zDHw4Yehj8EVV8Dhh4fmIpEUBYFILthhB5g4McyDPGsW7L03jB+vu4oEUBCI5A6zMA/yBx+EeQ3OOgt+8xtYsiTuyiRmCgKRXNO+fRi+euRImDwZ9toLXn457qokRgoCkVyUnx+GrZ4xA1q0gGOPDfMaaGiKnKQgEMll++4LM2fCZZfBmDHQrVsYs0hyioJAJNc1bAi33gpvvw3Nm8MJJ8CAAbBoUdyVSR1REIhI0L17ODu4/np45hnYYw94+GHdWZQDFAQislmDBnDNNfDvf8Puu4e7jI47Dr78Mu7KJEKRBYGZPWBmi8zsowreP93MPjCzD83sLTP7RVS1iEg1de0K06aFnslvvhmuHdxxh3olZ6kozwjGAcdW8v4XwK/dfS/gBmBMhLWISHXl5YWximbPhkMPDaOYHnwwzJkTd2WSZpEFgbu/AfxQyftvufvS1MvpQOuoahGRWmjbFl54IVwv+OyzcC3h++/jrkrSKFOuEZwL/N+K3jSzwWZWZGZFixcvrsOyRAQIvZLPOCMMYLdqFfz973FXJGkUexCY2eGEILi8om3cfYy7F7p7YUFBQd0VJyJb2n//cIYwaVLclUgaxRoEZrY3cB/Q19014IlIpjODfv3glVfCmYFkhdiCwMzaAs8AA93907jqEJFq6tcP1q7V+ERZJMrbRycAbwO7mVmxmZ1rZheY2QWpTUYAOwB/M7NZZlYUVS0ikkaHHALbb6/moSwS2ZzF7j5gK+//FvhtVPsXkYjUqxeGoZg8GX76CerXj7siqaXYLxaLSAL17x+mv3zjjbgrkTRQEIhI9fXsGeZEVvNQVlAQiEj1bbttmAt50iQNSpcFFAQiUjP9+0NxcRixVBJNQSAiNdO7d5jpTM1DiacgEJGa2WGHMBjds8/GXYnUkoJARGquf/8wGumn6hOaZAoCEam5vn3D43PPxVuH1IqCQERqrm1b2G8/NQ8lnIJARGqnf3+YPh2+/TbuSqSGFAQiUjv9+oW+BJqjILEUBCJSO926QadOah5KMAWBiNSOWWgemjIFli+PuxqpAQWBiNRev35hJNIXX4y7EqkBBYGI1N6BB8KOO6qXcUIpCESk9vLzQ5+CF1+EdevirkaqSUEgIunRrx+sWAGvvRZ3
JVJNCgIRSY8jjoAmTdQ8lEAKAhFJj4YNoVevMNzEpk1xVyPVoCAQkfTp1w8WLgw9jSUxFAQikj69eoXJ7NU8lCgKAhFJn2bNwrWCZ5/VFJYJoiAQkfTq1w/mzg3zFEgiKAhEJL369AmPah5KDAWBiKTXzjuHnsYahC4xFAQikn79+sHMmfD113FXIlWgIBCR9OvXLzyqeSgRFAQikn677Qb77AP33qu7hxJAQSAi0Rg6FGbPhpdfjrsS2QoFgYhEY8CAcOH41lvjrkS2QkEgItFo0AAuuQT+8Q94//24q5FKKAhEJDrnnw+NG+usIMMpCEQkOi1awDnnwIQJsGBB3NVIBRQEIhKtoUPDsNR33hl3JVIBBYGIRKtjRzjxxHAr6cqVcVcj5VAQiEj0hg2DH3+EBx6IuxIph4JARKJ34IHQowfcfjts2BB3NVJGlYLAzIaY2XYW3G9m75nZ0VEXJyJZZNgwmD9fg9FloKqeEZzj7suBo4EWwEBgZGU/YGYPmNkiM/uogvfNzO4ws7lm9oGZ7VetykUkWfr2hU6dwq2kGnYio1Q1CCz12At42N1nl1pXkXHAsZW8fxzQObUMBu6uYi0ikkT5+XDppTBjBrz1VtzVSClVDYKZZvYKIQheNrOmwKbKfsDd3wB+qGSTvsB4D6YDzc2sVRXrEZEkGjQo9C0YNSruSqSUqgbBucAVwAHuvhqoD5xdy33vApQerLw4te5nzGywmRWZWdHixYtruVsRiU3jxnDhhfDcc/DZZ3FXIylVDYKDgE/c/UczOwO4GlgWXVlbcvcx7l7o7oUFBQV1tVsRicLFF0P9+jB6dNyVSEpVg+BuYLWZ/QIYBnwOjK/lvhcAbUq9bp1aJyLZrFUrOP10ePBBWLIk7mqEqgfBBnd3Qrv+/3H3u4Cmtdz3ZODM1N1DBwLL3P3bWv5OEUmCyy6DNWvgnnvirkSoehCsMLMrCbeNvmBmeYTrBBUyswnA28BuZlZsZuea2QVmdkFqkxeBecBcYCxwUY2OQESSZ8894ZhjwvhDa9fGXU3Oq1fF7U4BTiP0J/jOzNoCf63sB9x9wFbed+B3Vdy/iGSbYcPg6KPhscfCCKUSmyqdEbj7d8CjQDMzOx5Y6+61vUYgIrnsqKNg773httvUwSxmVR1i4mTgHeA3wMnADDM7KcrCRCTLmYWzAs1rHDvzKiSxmb0P9HT3RanXBcA/3P0XEdf3M4WFhV5UVFTXuxWRKKxfDx06wB57hCktJTJmNtPdC8t7r6oXi/NKQiBlSTV+VkSkfA0awJAhMGUKTJsWdzU5q6pf5i+Z2ctmNsjMBgEvEO76ERGpnYsvhtatN89kJnWuqheLhwNjgL1Tyxh3vzzKwkQkR2y7Ldx8M7z3HozXPShxqNI1gkyiawQiWcg9TFwzf34Yg6hJk7gryjo1vkZgZivMbHk5ywozWx5NuSKSc8zC7GXffQcjK53qRCJQaRC4e1N3366cpam7b1dXRYpIDjjwwDAG0ahR8OWXcVeTU3Tnj4hkjptugrw8uOKKuCvJKQoCEckcbdrAH/8Ijz+uWczqkIJARDLL8OGw8866nbQOKQhEJLM0bhwuGL/7Ljz6aNzV5AQFgYhkntNPhwMOCNcKVq2Ku5qspyAQkcyTlxemsvzmG7jllriryXoKAhHJTD16wKmnhiD46qu4q8lqCgIRyVwlncuuvDLeOrKcgkBEMle7dvCHP4RZzN5+O+5qspaCQEQy2+WXQ6tWcOmlup00IgoCEclsTZqEHsczZsCECXFXk5UUBCKS+QYOhP33D7eTrl4ddzVZR0EgIpmv5HbS4mI1EUVAQSAiyXDwwWEcojFj4JxzYMOGuCvKGvXiLkBEpMpGjgzXDEaMgOXLwzWDbbaJu6rE0xmBiCSHGVxzDfzv/8Kzz8IJJ2gIijRQEIhI8vz+9zBuHEyZAj17wtKlcVeUaAoCEUmms86CiROhqAgOPxwWLoy7osRSEIhIcp14Ijz/
fJjw/pBDNCZRDSkIRCTZjj4aXn0VFi0KdxZ98kncFSWOgkBEkq9HD/jnP2HdunBmMGtW3BUlioJARLLDPvvAtGnQsCEcdhi8+WbcFSWGgkBEskeXLvCvf8GOO8Ixx8C8eXFXlAgKAhHJLm3bhttK8/LgwgvBPe6KMp6CQESyT5s28Je/wCuvaMTSKlAQiEh2uugi6N4dhg6FJUviriajKQhEJDvl58PYsaHX8fDhcVeT0RQEIpK99t4bhg2DBx+EqVPjriZjKQhEJLuNGAEdO8L558PatXFXk5EiDQIzO9bMPjGzuWZ2RTnvtzWzqWb2bzP7wMx6RVmPiOSgbbeFe+4Jw1DceGPc1WSkyILAzPKBu4DjgK7AADPrWmazq4En3X1f4FTgb1HVIyI5rGdPOOOMMJ/BnDlxV5Nxojwj6A7Mdfd57r4eeBzoW2YbB7ZLPW8GfBNhPSKSy267DZo2hcGDNdVlGVEGwS7A16VeF6fWlXYdcIaZFQMvApeU94vMbLCZFZlZ0eLFi6OoVUSyXUEB3HprGHpi7Ni4q8kocV8sHgCMc/fWQC/gYTP7WU3uPsbdC929sKCgoM6LFJEscdZZYe6Cyy+Hb7+Nu5qMEWUQLADalHrdOrWutHOBJwHc/W2gIdAywppEJJeZwb33hruHhgyJu5qMEWUQvAt0NrMOZtaAcDF4cpltvgKOBDCzPQhBoLYfEYlO585h3uOJE8OkNhJdELj7BuBi4GXgY8LdQbPN7Hoz65PabBhwnpm9D0wABrlrhCgRidjw4dCtWxiGYuXKuKuJnSXte7ewsNCLioriLkNEku6tt+BXvwpjEd1+e9zVRM7MZrp7YXnvxX2xWEQkHj16hGGq77gDcvyPSwWBiOSum26CnXaCc8+F9evjriY2CgIRyV3NmoXhJz74AG6+Oe5qYqMgEJHc1qcPnHoq3HADfPRR3NXEQkEgInLHHeHs4JxzYMOGuKupcwoCEZGCArjzTnj3XRg9Ou5q6pyCQEQE4JRToG/f0Nns00/jrqZOKQhERCAMP/G3v8E228Bvf5tTI5QqCERESuy8c+hcNm0a3H133NXUGQWBiEhpgwbB0UeHEUrnz4+7mjqhIBARKc0MxowJj4MHQ8KG4akJBYGISFnt2oUOZq++CuPGxV1N5BQEIiLlueACOPRQuPRS+Ca7Z9FVEIiIlCcvD+67D9atC4PTZXETkYJARKQinTuHoScmT4Ynnoi7msgoCEREKnPppdC9O1xyCSzOzgkUFQQiIpXJz4cHHoBly0IYZCEFgYjI1nTrBldfHZqHpk6Nu5q0UxCIiFTF8OHQti384Q9ZN/yEgkBEpCoaNYIbb4T33oNHH427mrRSEIiIVNWAAVBYCH/6E6xZE3c1aaMgEBGpqrw8GDUKiovD4HRZQkEgIlIdv/51mLfgpptg4cK4q0kLBYGISHXdcgusXQvXXRd3JWmhIBARqa4uXcJYRGPHwpw5cVdTawoCEZGaGDECGjcO8xYknIJARKQmCgrgqqvg+efhtdfirqZWFAQiIjX1+9+HuQuGDUt0JzMFgYhITTVsGO4emjULHnkk7mpqTEEgIlIbp5wCBxwQOpmtXh13NTWiIBARqY28PLj1VliwILGdzBQEIiK1dcgh0L8/jByZyE5mCgIRkXQYOTJ0Mrv22rgrqTYFgYhIOnTpAhddlMhOZgoCEZF0ueYaaNo0zF1QEXdYtSpcU5g9G775pu7qq0C9uAsQEckaLVuGmcyGD4fTToOffoIff/z5smHD5p+pVw9eegmOPDKWkgHM3WPbeU0UFhZ6UVFR3GWIiJRv7dowQunXX0Pz5j9fmjXb8vn//E84O3j7bdh998jKMrOZ7l5Y3ns6IxARSaeGDWHGjKpv/8tfhuX442H69HBWUccivUZgZsea2SdmNtfMrqhgm5PNbI6ZzTazx6KsR0Qk43ToAM89Fya7OfFEWLeuzkuILAjMLB+4CzgO6AoMMLOuZbbpDFwJ
/MrduwFDo6pHRCRjHXQQPPggTJsGgweHC8p1KMozgu7AXHef5+7rgceBvmW2OQ+4y92XArj7ogjrERHJXAMGhIluxo8P4xfVoSiDYBfg61Kvi1PrSusCdDGzN81supkdW94vMrPBZlZkZkWLFy+OqFwRkZiNGBHuNrrqKpg4sc52G3c/gnpAZ+AwYAAw1syal93I3ce4e6G7FxYUFNRthSIidcUM7r8fevSAM8+Ed96pk91GGQQLgDalXrdOrSutGJjs7j+5+xfAp4RgEBHJTQ0bwqRJ0KoV9OkDX30V+S6jDIJ3gc5m1sHMGgCnApPLbDOJcDaAmbUkNBXNi7AmEZHMV1AQZj5bsybcVrpiRaS7iywI3H0DcDHwMvAx8KS7zzaz682sT2qzl4ElZjYHmAoMd/clUdUkIpIYXbvCU0+FcYsGDICNGyPblXoWi4hksnvugQsvhCFDYPToGv8a9SwWEUmqCy6ATz4JIVAywmmaKQhERDLdqFFhwps2bba+bQ0oCEREMl1+PjwW3Qg8cfcjEBGRmCkIRERynIJARCTHKQhERHKcgkBEJMcpCEREcpyCQEQkxykIRERyXOLGGjKzxcCXZVa3BL6PoZyoZNvxQPYdU7YdD2TfMWXb8UDtjqmdu5c7oUvigqA8ZlZU0WBKSZRtxwPZd0zZdjyQfceUbccD0R2TmoZERHKcgkBEJMdlSxCMibuANMu244HsO6ZsOx7IvmPKtuOBiI4pK64RiIhIzWXLGYGIiNSQgkBEJMclOgjM7Fgz+8TM5prZFXHXkw5mNt/MPjSzWWaWyMmZzewBM1tkZh+VWre9mb1qZp+lHlvEWWN1VHA815nZgtTnNMvMesVZY3WYWRszm2pmc8xstpkNSa1P8mdU0TEl8nMys4Zm9o6ZvZ86nj+n1ncwsxmp77wnzKxBWvaX1GsEZpYPfAr0BIqBd4EB7j4n1sJqyczmA4XuntiOMGZ2KLASGO/ue6bW3QL84O4jU6Hdwt0vj7POqqrgeK4DVrr7qDhrqwkzawW0cvf3zKwpMBPoBwwiuZ9RRcd0Mgn8nMzMgMbuvtLM6gP/AoYAlwHPuPvjZnYP8L67313b/SX5jKA7MNfd57n7euBxoG/MNQng7m8AP5RZ3Rd4KPX8IcJ/0kSo4HgSy92/dff3Us9XAB8Du5Dsz6iiY0okD1amXtZPLQ4cATyVWp+2zyjJQbAL8HWp18Uk+IMvxYFXzGymmQ2Ou5g02sndv009/w7YKc5i0uRiM/sg1XSUmGaU0sysPbAvMIMs+YzKHBMk9HMys3wzmwUsAl4FPgd+dPcNqU3S9p2X5CDIVge7+37AccDvUs0SWcVDe2Qy2yQ3uxvoBOwDfAvcGms1NWBmTYCngaHuvrz0e0n9jMo5psR+Tu6+0d33AVoTWkB2j2pfSQ6CBUCbUq9bp9YlmrsvSD0uAp4l/APIBgtT7bgl7bmLYq6nVtx9Yeo/6iZgLAn7nFLtzk8Dj7r7M6nVif6MyjumpH9OAO7+IzAVOAhobmb1Um+l7TsvyUHwLtA5dRW9AXAqMDnmmmrFzBqnLnRhZo2Bo4GPKv+pxJgMnJV6fhbwXIy11FrJF2ZKfxL0OaUuRN4PfOzut5V6K7GfUUXHlNTPycwKzKx56nkjwk0xHxMC4aTUZmn7jBJ71xBA6law0UA+8IC7/yXeimrHzDoSzgIA6gGPJfGYzGwCcBhhyNyFwLXAJOBJoC1hGPGT3T0RF2ArOJ7DCM0NDswHzi/Vvp7RzOxgYBrwIbAptfpPhDb1pH5GFR3TABL4OZnZ3oSLwfmEP9ifdPfrU98RjwPbA/8GznD3dbXeX5KDQEREai/JTUMiIpIGCgIRkRynIBARyXEKAhGRHKcgEBHJcQoCkTpkZoeZ2fNx1yFSmoJARCTHKQhEymFmZ6TGg59lZvemBgBbaWa3p8aHn2JmBalt9zGz6amBzZ4tGdjMzHY1s3+kxpR/z8w6pX59EzN7ysz+
Y2aPpnrFisRGQSBShpntAZwC/Co16NdG4HSgMVDk7t2A1wk9jAHGA5e7+96Enq0l6x8F7nL3XwA9CIOeQRgZcyjQFegI/CriQxKpVL2tbyKSc44E9gfeTf2x3ogwANsm4InUNo8Az5hZM6C5u7+eWv8QMDE1ZtQu7v4sgLuvBUj9vnfcvTj1ehbQnjDxiEgsFAQiP2fAQ+5+5RYrza4ps11Nx2cpPTbMRvT/UGKmpiGRn5sCnGRmO8L/n8u3HeH/S8nIj6cB/3L3ZcBSMzsktX4g8HpqlqxiM+uX+h3bmNm2dXkQIlWlv0REynD3OWZ2NWGmuDzgJ+B3wCqge+q9RYTrCBCGA74n9UU/Dzg7tX4gcK+ZXZ/6Hb+pw8MQqTKNPipSRWa20t2bxF2HSLqpaUhEJMfpjEBEJMfpjEBEJMcpCEREcpyCQEQkxykIRERynIJARCTH/T/r3JGrEf93bwAAAABJRU5ErkJggg==\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Training\n",
    "SEED = 0\n",
    "torch.manual_seed(SEED)  # fixed seed so weight init and DataLoader shuffling are reproducible\n",
    "use_gpu = torch.cuda.is_available()\n",
    "device = 'cuda' if use_gpu else 'cpu'\n",
    "torch.backends.cudnn.enabled = True\n",
    "# benchmark=True picks the fastest kernels but makes GPU runs non-deterministic.\n",
    "torch.backends.cudnn.benchmark = True\n",
    "model = JigsawNet().to(device)\n",
    "# Release memory held by objects from earlier runs before allocating optimizer state.\n",
    "gc.collect()\n",
    "torch.cuda.empty_cache()\n",
    "# Keep the loss module on the same device as the model.\n",
    "loss_func = torch.nn.CrossEntropyLoss().to(device)\n",
    "optimizer = torch.optim.Adam(model.parameters(), lr=0.001)\n",
    "loss_count = []\n",
    "acc_count = []\n",
    "EPOCHS = 30\n",
    "for epoch in range(EPOCHS):  # `train` reads the global `epoch` for its progress label\n",
    "    train(data_loader_train, loss_count)\n",
    "torch.save(model.state_dict(), './cnn_jigsaw.pth')\n",
    "\n",
    "# float() accepts both plain numbers and 0-d tensors stored in loss_count.\n",
    "fig, ax = plt.subplots()\n",
    "ax.plot(range(1, EPOCHS + 1), [float(v) for v in loss_count], c='r')\n",
    "ax.set(title='Training loss per epoch', xlabel='epoch', ylabel='loss')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
